/*
    Copyright 2006-2012 Patrik Jonsson, sunrise@familjenjonsson.org

    This file is part of Sunrise.

    Sunrise is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 3 of the License, or
    (at your option) any later version.

    Sunrise is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with Sunrise.  If not, see <http://www.gnu.org/licenses/>.

*/

/// \file
/// The FITS I/O functions of the nbody_data_grid class.
/// \ingroup makegrid

#include "config.h"
#include "sphgrid.h"
#include "grid-fits.h"
#include "CCfits/CCfits"
#include "blitz-fits.h"
#include "counter.h"
#include <hpm.h> 
#include "fits-utilities.h"
#include "snapshot.h"

using namespace CCfits;
using namespace std;
using namespace blitz;

/** Constructs the grid from a FITS HDU containing the grid structure.

    The refinement structure itself is read by the
    adaptive_grid<nbody_data_cell> base-class constructor; here only the
    "lengthunit" header keyword is read.  The per-cell data are
    presumably filled in afterwards (e.g. by load_data) -- confirm
    against callers.

    \param input            HDU holding the grid structure.
    \param translate_origin Origin offset forwarded to the adaptive_grid
                            constructor. */
mcrx::nbody_data_grid::nbody_data_grid (CCfits::ExtHDU& input,
					const vec3d& translate_origin)
  : adaptive_grid<nbody_data_cell> (input, 0, translate_origin)
  , ctr (1)
  , n_threads (0)
  , work_chunk_levels (0)
{
  // The grid's length unit is stored as a header keyword of the HDU.
  input.readKey ("lengthunit", length_unit);
}

/** Loads a snapshot and creates the adaptive refinement grid from it.

    Only the gas particles of the snapshot are put in the refinement
    list.  Each becomes an nbody_data carrying its mass, momentum
    (mass*velocity), SFR, metal mass and mass-weighted temperature and
    effective temperature.  The particles are deleted again once the
    grid has been built.

    \param snap       Snapshot supplying the gas particle set and units.
    \param ml         Maximum refinement level.
    \param size_fudge Factor stored in size_factor.
    \param tc         Refinement tolerances; the grid keeps its own copy.
    \param info       If non-null, refinement statistics and grid totals
                      are written to this HDU as comments and keywords.
    \param use_SED    Whether emission (L_lambda) information is used.
    \return Always true.

    NOTE(review): L_lambda_width is fixed at 0 and L_bol_tot is never
    accumulated below.  With use_SED == true, the assertion after
    data_zero is set up therefore fails in debug builds, and no emission
    data is kept in release builds. */
bool mcrx::nbody_data_grid::load_snapshot (const Snapshot& snap,
					   const int ml, 
					   const T_float size_fudge,
					   const tolerance_checker& tc,
					   HDU* info,
					   bool use_SED)
{
  max_level = ml;
  size_factor = size_fudge;
  tol_checker_ = auto_ptr<tolerance_checker>(new tolerance_checker(tc));

  cout << "  Building grid from snapshot \'" << snap.name()
       << "\'" << endl;

  if (use_SED)
    cout << "   Emission information will be used" << endl;
  else
    cout << "   Emission information will not be used" << endl;

  mass_unit = snap.units().get("mass");
  time_unit = snap.units().get("time");
  length_unit = snap.units().get("length");
  SFR_unit = mass_unit+"/"+time_unit;
  temp_unit = snap.units().get("temperature");

  int L_lambda_width = 0;
  T_float m_g_tot = 0, L_bol_tot = 0, SFR_tot = 0;
  T_float metals_tot = 0; // this is the total MASS of metals

  // The only particles that are being put in the list are the GAS
  // particles. (We will at least for now explicitly avoid the
  // possibility of having the emission in the grid.)

  // syntactic convenience
  const gas_particle_set& s 
    = dynamic_cast<const gas_particle_set&>(*snap.sets()[Snapshot::gas]);
  
  for (int i=0; i< s.n(); ++i) {
    m_g_tot+= s.m(i);
    SFR_tot+= s.sfr(i);
    metals_tot+= s.m(i)*s.z(i);

    // make gas particle (mass-weighted temperature)
    nbody_data d(s.m(i), s.m(i)*s.vel(i), 0, s.sfr(i), s.m(i)*s.z(i), 0, vec3d(0,0,0),
		 0, 0, 0,
		 s.temp(i)*s.m(i), s.teff(i)*s.m(i)); 
    // velocity of particle isn't used here, it's the momentum in the nbody_data
    T_particle* p = new T_particle (s.id(i), s.pos(i), vec3d(0,0,0), s.h(i), d);
    particle_list_.push_back(p );
  }

  cout << "Total number of particles used for refinement: " <<
    particle_list_.size() << endl;
  
  {
    // setting L_lambda in data_zero to empty is what inhibits future
    // use of L_lambda in the grid making
    array_1 zero (use_SED?  L_lambda_width: 0);
    zero = 0;
    data_zero = nbody_data_cell (0, vec3d(0,0,0), 0, 0, 0, 0, vec3d(0,0,0), 0,0,0,0,0,zero);
  }
  
  assert (!use_SED || data_zero.L_lambda.size()> 0); 

  // Scale absolute tolerance by total bolometric luminosity (note
  // that this will be incorrect if the grid volume only covers part
  // of the snapshot, and that L_bol_tot is currently always 0)
  tol_checker_->set_L_bol(L_bol_tot);
  tol_checker_->print_tolerances();

  cout << "  Building grid" << endl;
  
  // All the particles are now in the list, refine the grid around them.
  build_grid ();

  cout << "  Done, cleaning up."  << endl;
  
  // now make sure we delete the particles
  for (vector<T_particle*>::iterator i = particle_list_.begin();
       i != particle_list_.end(); ++i)
    delete *i;
  particle_list_.clear();

  // print creation refinement statistics
  vector<int> stats = statistics ();
  stats.resize(max_level+1);
  cout << "  Grid refinement statistics:\n";
  if (info)
    info->writeComment("Grid refinement statistics");
  for (int i = 0; i <= max_level; ++i) {
    ostringstream line;
    line << i << "    " << stats [i] << ", " << creation_stats [i] 
	 << " created" ;
    cout << "    " << line.str () << endl;
    if (info) 
      info->writeComment(line.str ()) ;
  } 
  {
    ostringstream line;
    line << "Total number of cells: " << n_cells ();
    cout << "  " << line.str () << endl;
    if (info)
      info->writeComment(line.str ());
  }

  // Calculate total grid quantities and check EVERY cell (including
  // the first one, which the totals are initialized from) against the
  // metal column tolerance.  An empty grid leaves total_quantities
  // untouched instead of dereferencing begin().
#ifdef MCRX_DEBUG_LEVEL
  ofstream ooo("cell_optical_depths.txt");
#endif
  int n_over=0;
  T_float max_over=0;
  bool first_cell = true;
  for (const_iterator c = begin (), e = end(); c != e; ++c) {
    const T_float tau=c->data()->m_metals()/pow(c.volume(),2./3);
    DEBUG(1,ooo << tau*3e-5 << '\t' << c.volume() << '\n';);
    if(tau>tol_checker_->max_metal_column()) {
      max_over = std::max(max_over, tau/tol_checker_->max_metal_column());
      ++n_over;
    }
    if (first_cell) {
      total_quantities = *c->data();
      first_cell = false;
    }
    else
      total_quantities += *c->data();
  }
  if (n_over>0) {
    cout << "WARNING: " << n_over << " grid cells exceed tau tolerance" << endl;
    cout << "  Largest excess " << max_over << endl;
  }

  cout
    << "  Total snapshot quantities: " << m_g_tot << mass_unit
    << ", " << L_bol_tot << L_bol_unit 
    << ", " << SFR_tot << SFR_unit 
    << ", " << metals_tot << mass_unit 
    << "\n  Total grid quantities: " << total_quantities.m_g_ << mass_unit
    << ", " << total_quantities.L_bol() << L_bol_unit 
    << ", " << total_quantities.SFR << SFR_unit 
    << ", " << total_quantities.m_metals() << mass_unit << endl;
  if (info) {
    info->addKey("M_g_tot", total_quantities.m_g(), "[" + mass_unit +
		 "] Total gas mass in all cells");
    info->addKey("L_bol_tot", total_quantities.L_bol(), "[" + L_bol_unit +
		 "] Total bolometric luminosity of all cells");
    info->addKey("SFR_tot", total_quantities.SFR, "[" + SFR_unit +
		 "] Total star formation rate of all cells");
    info->addKey("M_metals_tot", total_quantities.m_metals(), "[" + mass_unit +
		 "] Total mass of metals in all cells");
  }    

  return true;
}


/** Saves the grid cell data to a FITS HDU.  (The grid structure is
    written by the adaptive_grid class.)

    A binary table named \p hdu is added to \p file with one row per
    grid cell.  The gas columns (mass_gas, p_gas, SFR, mass_metals,
    gas_temp_m, gas_teff_m, cell_volume) are always written.  If the
    grid carries emission data (data_zero.L_lambda is non-empty), the
    stellar/emission columns are written as well.  In that case L_lambda
    is stored as log10 in a float column, flagged by the "logflux"
    keyword.

    \param file   FITS file to add the table(s) to.
    \param hdu    Name of the cell-data table.
    \param iq_hdu If non-empty and emission data is present, name of an
                  additional table receiving the SED summed over all
                  cells. */
void mcrx::nbody_data_grid::save_data (FITS& file, const std::string& hdu,
				       const std::string& iq_hdu) const
{
  cout << "Saving grid data in HDU " << hdu << endl;

  cout << " collecting grid data" << endl;
  //  We will collect all the data first, and then write.
  const bool save_SED = data_zero.L_lambda.size() > 0;
  const int n = n_cells (); // unfortunately, we have to...

  vector<T_float> m_g, L_bol;
  m_g.reserve(n);
  vector<T_float> SFR, m_metals, volume,m_s,m_m_s,age_m,age_l,gas_temp,gas_teff;
  vector<vec3d> p_g, p_s;
  p_g.reserve(n);
  SFR.reserve(n);
  m_metals.reserve(n);
  volume.reserve(n);
  gas_temp.reserve(n);
  gas_teff.reserve(n);
  if (save_SED) {
    L_bol.reserve(n);
    m_s.reserve(n);
    p_s.reserve(n);
    m_m_s.reserve(n);
    age_m.reserve(n);
    age_l.reserve(n);
  }
  array_2 L_lambda (n, data_zero.L_lambda.size());
  
  int i = 0;
  for (const_iterator c = begin (); c != end (); ++c) {
    m_g.push_back(c->data()->m_g() );
    p_g.push_back(c->data()->p_g() );
    SFR.push_back(c->data()->SFR );
    m_metals.push_back(c->data()->m_metals() );
    gas_temp.push_back(c->data()->gas_temp_m );
    gas_teff.push_back(c->data()->gas_teff_m );
    volume.push_back(c.volume () );
    if (save_SED) {
      L_bol.push_back(c->data()->L_bol() );
      m_s.push_back(c->data()->m_s_ );
      p_s.push_back(c->data()->p_s_ );
      m_m_s.push_back(c->data()->m_m_s );
      age_m.push_back(c->data()->age_m );
      age_l.push_back(c->data()->age_l );
      L_lambda (i, Range::all ()) = c->data()->L_lambda;
    }
    ++i;
  }

  // also write the integrated SED of the entire grid
  if ((iq_hdu != "") && save_SED) {
    cout << "Writing integrated grid SED to HDU " << iq_hdu << endl;
    Table* output = file.addTable(string (iq_hdu) , 0);
    output->writeComment("This HDU contains the integrated SED of all cells in the grid");
    output->addColumn(Tdouble, "L_lambda", 1, L_lambda_unit );

    // sum over cells (first index) for every wavelength
    array_1 integrated_L_lambda (sum (L_lambda (blitz::tensor::j,
						blitz::tensor::i),
				      blitz::tensor::j));
    write (output->column("L_lambda" ), integrated_L_lambda, 1 );
  }

  Table* output = file.addTable(string (hdu) , 0);

  output->addKey("M_g_tot", total_quantities.m_g(), "[" + mass_unit +
		 "] Total gas mass in all cells");
  output->addKey("SFR_tot", total_quantities.SFR, "[" + SFR_unit +
		 "] Total star formation rate of all cells");
  output->addKey("timeunit", time_unit, "Time unit is "+time_unit);
  // The keyword comment must report the temperature unit, not the time unit.
  output->addKey("tempunit", temp_unit, "Temperature unit is "+temp_unit);
  
  output->addColumn(Tdouble, "mass_gas", 1, mass_unit );
  output->addColumn(Tdouble, "p_gas", 3, mass_unit+"*"+length_unit+"/"+time_unit);
  output->addColumn(Tdouble, "SFR", 1, SFR_unit );
  output->addColumn(Tdouble, "mass_metals", 1, mass_unit );
  output->addColumn(Tdouble, "gas_temp_m", 1, temp_unit+"*"+mass_unit );
  output->addColumn(Tdouble, "gas_teff_m", 1, temp_unit+"*"+mass_unit );
  output->addColumn(Tdouble, "cell_volume", 1, length_unit + "^3" );

  Column& c_mg = output->column("mass_gas");
  Column& c_pg = output->column("p_gas");
  Column& c_sfr = output->column("SFR");
  Column& c_mm = output->column("mass_metals");
  Column& c_temp = output->column("gas_temp_m");
  Column& c_teff = output->column("gas_teff_m");
  Column& c_vol = output->column("cell_volume");

  if (save_SED) {
    output->addKey("L_bol_tot", total_quantities.L_bol(), "[" + L_bol_unit +
		   "] Total bolometric luminosity of all cells");
    output->addKey("M_metals_tot", total_quantities.m_metals(), "[" + mass_unit +
		   "] Total metal mass in all cells");

    output->addColumn(Tdouble, "L_bol", 1, L_bol_unit );
    output->addColumn(Tdouble, "mass_stars", 1, mass_unit );
    output->addColumn(Tdouble, "p_stars", 3, mass_unit+"*"+length_unit+"/"+time_unit);
    output->addColumn(Tdouble, "mass_stellar_metals", 1, mass_unit );
    output->addColumn(Tdouble, "age_m", 1, time_unit+"*"+mass_unit );
    output->addColumn(Tdouble, "age_l", 1, time_unit+"*"+mass_unit );
    
    // NOTE: we write L_lambda as a float column of log.  Because 0
    // values are complicated, we add a very small amount to every
    // number
    L_lambda =log10(L_lambda + blitz::tiny (T_float ()));
    output->addKey("logflux", true,
		   "Column L_lambda values are log (L_lambda)");
    output->addColumn(Tfloat, "L_lambda", L_lambda.columns(), L_lambda_unit );
  }

  cout << " writing FITS file\n";

  int n_chunk = output->getRowsize();
  // A non-positive chunk size would make the write loop below spin
  // forever without advancing, so always write at least one row.
  if (n_chunk < 1)
    n_chunk = 1;
  cout << " chunk size is " << n_chunk << endl;
  const int start_row = 1;
  int j = 0;
  counter c (144);
  while (j < n) {
    const int this_chunk = (n- j > n_chunk)?  n_chunk: n-j;
    const int current_row = j+ start_row;
    make_onedva(m_g).write_column(*output, c_mg, current_row, j, this_chunk);
    write(c_pg, p_g, current_row, j, this_chunk);
    make_onedva(SFR).write_column(*output, c_sfr, current_row, j, this_chunk);
    make_onedva(m_metals).write_column(*output, c_mm, 
				       current_row, j, this_chunk);
    make_onedva(gas_temp).write_column(*output, c_temp, 
				       current_row, j, this_chunk);
    make_onedva(gas_teff).write_column(*output, c_teff, 
				       current_row, j, this_chunk);
    make_onedva(volume).write_column(*output, c_vol, 
				     current_row, j, this_chunk);

    if (save_SED) {
      output->column("L_bol").write( &L_bol [j], this_chunk, current_row );
      output->column("mass_stars").write( &m_s [j],
					  this_chunk, current_row );
      // Same vec3d writer as p_gas; avoids punning vec3d to T_float*.
      write(output->column("p_stars"), p_s, current_row, j, this_chunk);
      output->column("mass_stellar_metals").write( &m_m_s [j],
						   this_chunk, current_row );
      output->column("age_m").write( &age_m [j],
				     this_chunk, current_row );
      output->column("age_l").write( &age_l [j],
				     this_chunk, current_row );
      write (output->column("L_lambda" ),
	     L_lambda (Range (j, j + this_chunk-1 ),
		       Range::all ()),  
	     current_row );
    }
    j+= this_chunk;
    c+=this_chunk;
  }
}
 
/** Loads the per-cell data of the grid from a FITS table HDU.

    The grid structure must already have been loaded when the grid
    object was constructed; this only fills in the cell data.

    The gas columns "mass_gas", "SFR", "mass_metals", "gas_temp_m" and
    "gas_teff_m" are required.  If the emission columns ("mass_stars",
    "mass_stellar_metals", "age_m", "age_l", "L_bol", "L_lambda") can
    all be read, they are loaded too.  Otherwise the cells are filled
    with gas data only.

    Momenta are not read back: gas and stellar momenta are set to zero.
    The mass-weighted temperatures are converted to temp_unit
    ("K*" + mass unit) on both the emission and the gas-only path.

    \param hdu Table HDU, e.g. one written by save_data.
    \throws int (0) if the number of rows does not match the number of
            grid cells. */
void mcrx::nbody_data_grid::load_data (ExtHDU& hdu)
{
  // The grid structure already has to have been loaded upon
  // constructor of the grid_base object.

  // First read the data that is always there

  cout << "Loading grid data from FITS file" << endl;

  Column& c_m_g = hdu.column("mass_gas");
  Column& c_SFR = hdu.column("SFR");
  Column& c_metals = hdu.column("mass_metals");
  // these two are temp*mass, not just temperature
  Column& c_temp_m = hdu.column("gas_temp_m");
  Column& c_teff_m = hdu.column("gas_teff_m");

  // length unit is read when reading gridstructure
  mass_unit = c_m_g.unit();
  SFR_unit = c_SFR.unit();
  assert(c_teff_m.unit()==c_temp_m.unit());
  temp_unit = "K*"+mass_unit;
  // Factor converting the stored temperature*mass values to temp_unit;
  // applied to every cell regardless of which branch below is taken.
  const T_float tempconv = units::convert(c_temp_m.unit(),temp_unit);

  const int n = n_cells ();
  if (c_m_g.rows() != n) {
    cerr << "nbody_data_grid::load_data: Error: HDU does not contain the expected number of rows" << endl;
    throw 0;
  }

  hdu.readKey("timeunit",time_unit);
  assert(time_unit!="");
  assert(mass_unit!="");

  vector<T_float> m_g;
  vector<vec3d> p_g, p_s;
  vector<T_float> L_bol;
  vector<T_float> SFR;
  vector<T_float> metals;
  vector<T_float> m_s,m_m_s,age_m,age_l,gas_temp_m,gas_teff_m;
  array_2 L_lambda;

  c_m_g.read(m_g, 1, c_m_g.rows());
  p_g.assign(m_g.size(), vec3d(0,0,0));
  cerr << "Warning: Setting momentum of gas cells to zero when reading" << endl;
  c_SFR.read(SFR, 1, c_SFR.rows() ); 
  c_metals.read(metals, 1, c_metals.rows() ); 
  c_temp_m.read(gas_temp_m, 1, c_temp_m.rows() ); 
  c_teff_m.read(gas_teff_m, 1, c_teff_m.rows() ); 

  // try to read the lum info (we never really use this, so stale warning!)
  try {
    Column& c_ms = hdu.column("mass_stars");
    Column& c_mms = hdu.column("mass_stellar_metals");
    Column& c_agem = hdu.column("age_m");
    Column& c_agel = hdu.column("age_l");
    Column& c_l_bol = hdu.column("L_bol");
    Column& c_l_lambda = hdu.column("L_lambda");

    L_bol_unit = c_l_bol.unit();
    L_lambda_unit = c_l_lambda.unit();

    c_ms.read(m_s, 1, c_ms.rows() ); 
    c_mms.read(m_m_s, 1, c_mms.rows() ); 
    c_agem.read(age_m, 1, c_agem.rows() ); 
    c_agel.read(age_l, 1, c_agel.rows() ); 
    c_l_bol.read(L_bol, 1, c_l_bol.rows() );
    read (c_l_lambda, L_lambda);

    // The stellar momentum column is not read either; p_s must still
    // hold one entry per cell because it is indexed in the loop below.
    p_s.assign(m_g.size(), vec3d(0,0,0));
    cerr << "Warning: Setting momentum of stars in cells to zero when reading" << endl;

    cout << "\tloading grid emission data" << endl;

    // look for logarithmic flux (keyword is optional)
    try {
      bool logarithmic= false;
      hdu.readKey("logflux", logarithmic);
      if (logarithmic)
	L_lambda  = pow (10., L_lambda);
    }
    catch (...) {}

    int i = 0;
    for (iterator c = begin (); c != end(); ++c, ++i) {
      if(!c->data()) {
	c->set_data(new T_data() );
      }

      T_data d (m_g [i], p_g [i], L_bol [i], SFR [i], metals [i],
		m_s[i], p_s [i], m_m_s[i], age_m[i], age_l[i], 
		gas_temp_m[i]*tempconv, 
		gas_teff_m[i]*tempconv,
		L_lambda (i, Range::all ()));
      
      *c->data() = d;
    }

  }
  catch (...) {
    // Best effort: if any part of the emission data could not be read,
    // proceed with the gas data only.
    cout << "\tno usable grid emission data, loading gas data only" << endl;
    int i = 0;
    for (iterator c = begin (); c != end(); ++c, ++i) {
      // if the cell has no data member, default-construct one before assigning
      if(!c->data()) {
	c->set_data(new T_data() );
      }
  
      // temp_unit was set to the converted unit above, so the values
      // must be converted here exactly as on the emission path.
      T_data d (m_g [i], p_g [i], 0, SFR [i], metals [i],
		0, vec3d (0,0,0), 0, 0, 0,
		gas_temp_m[i]*tempconv, gas_teff_m[i]*tempconv);
      *c->data() = d;
    }
  }
}
